#!/usr/bin/env python
# coding: utf-8
#current

# In[1]:


import numpy as np
from copy import deepcopy
from pomdp import constrained_pomdp
from itertools import product


# In[2]:


def probsum(iterable):
    """Return the sum of the values in ``iterable``, accumulated left to right.

    The accumulation starts from the integer ``0``, so an empty iterable
    yields ``0``.

    A plain loop is used deliberately instead of the builtin ``sum``: the
    result is compared against ``1`` by the caller, so the exact
    floating-point accumulation order must not depend on how the builtin is
    implemented in a given Python version.
    """
    total = 0
    for value in iterable:
        total = total + value

    return total


# In[3]:


def man_dist(a,b):
    """Return the Manhattan (L1) distance between ``a`` and ``b``.

    Only the first two components of each argument are used.
    """
    first_axis_gap = abs(a[0] - b[0])
    second_axis_gap = abs(a[1] - b[1])
    return first_axis_gap + second_axis_gap


# In[4]:


def gridworld(m,n,p_obstacle_locations,obstacle_location,horizon,p,qr_n,qr_a,ql_n,ql_a,r,rew):
    """Build a finite-horizon grid-world ``constrained_pomdp`` with a hidden obstacle.

    A state is a pair ``((i, j), o)``: the agent's cell on an ``m`` x ``n``
    grid and one candidate obstacle location ``o``.  States are flattened to
    the index ``(n*i + j)*n_o + k`` where ``k`` is the position of ``o`` in
    ``p_obstacle_locations``.  The obstacle component never changes in a
    transition.

    Parameters
    ----------
    m, n : int
        Number of grid rows and columns.
    p_obstacle_locations : sequence
        Candidate obstacle locations.  Each entry must be hashable and
        indexable as ``(row, col)``; entries are assumed to be distinct.
    obstacle_location :
        Not used by this function; kept for interface compatibility.
    horizon : int
        Number of time steps.  States, actions, observations, rewards and
        observation probabilities are keyed by ``1..horizon``; transitions by
        ``1..horizon-1``.  A ``KeyError`` is raised if ``horizon < 1``.
    p : float
        Probability that the intended move is executed.  Each of the two
        perpendicular moves gets ``(1-p)/2``; the opposite move never happens.
        Probability mass of moves that would leave the grid stays in place.
    qr_n, qr_a : float
        Sensor parameters used when the obstacle column is ``>= n/2``.
    ql_n, ql_a : float
        Sensor parameters used otherwise.  In both cases ``q_n`` is the
        probability of reading ``'C'`` when the Manhattan distance between
        agent and obstacle is ``<= 1``, and ``q_a`` is the probability of
        reading ``'F'`` when it is larger.
    r : float
        Probability that the observed position is the true cell.  The
        remaining ``1-r`` is split evenly over the in-grid neighbouring cells
        (up to eight, diagonals included).
    rew : indexable
        ``rew[n*i + j]`` is the reward for being in cell ``(i, j)``; it does
        not depend on the action, the obstacle or the time step.

    Returns
    -------
    The object returned by ``constrained_pomdp(...)``, built with empty
    constraint containers.  The initial distribution places the agent in cell
    ``(0, 0)`` with the obstacle uniform over the candidates.

    Notes
    -----
    The neighbourhood used for position noise excludes only the agent's own
    cell.  The earlier ``dx + dy != 0`` filter also dropped the two
    anti-diagonal neighbours ``(-1, 1)`` and ``(1, -1)``, which was
    unintended.
    """
    n_o = len(p_obstacle_locations)
    o_dict = {loc: k for k, loc in enumerate(p_obstacle_locations)}
    n_states = m*n*n_o

    states = {t: [((i, j), o) for i in range(m) for j in range(n) for o in p_obstacle_locations]
              for t in range(1, horizon+1)}
    actions = {t: ['U', 'D', 'L', 'R'] for t in range(1, horizon+1)}

    # The model is time-homogeneous, so everything is built from step 1.
    state_list = states[1]
    action_list = actions[1]

    # (intended action, opposite action, row delta, column delta).
    # The order matters: probabilities are summed in this order before being
    # compared against 1 below.
    moves = (('D', 'U', 1, 0), ('U', 'D', -1, 0), ('L', 'R', 0, -1), ('R', 'L', 0, 1))

    trans_step = {}
    for a in action_list:
        matrix = np.zeros((n_states, n_states))
        for s, o in state_list:
            x, y = s[0], s[1]
            k = o_dict[o]
            row = (n*x + y)*n_o + k
            move_probs = []
            for intended, opposite, dx, dy in moves:
                nx, ny = x + dx, y + dy
                # The agent never moves opposite to the chosen action and
                # never leaves the grid.
                if a != opposite and 0 <= nx < m and 0 <= ny < n:
                    prob = p if a == intended else (1-p)/2.0
                    matrix[row][(n*nx + ny)*n_o + k] = prob
                    move_probs.append(prob)

            # Whatever mass was lost to blocked moves keeps the agent in place.
            prob_sum = probsum(move_probs)
            if prob_sum < 1:
                matrix[row][row] = 1 - prob_sum
        trans_step[a] = matrix

    transitions = {t: deepcopy(trans_step) for t in range(1, horizon)}

    # Agent starts in cell (0, 0); the obstacle is uniform over the candidates.
    initial_dist = np.zeros(n_states)
    for k in range(n_o):
        initial_dist[k] = 1.0/n_o

    observations = {t: [((i, j), o) for i in range(m) for j in range(n) for o in ['C', 'F']]
                    for t in range(1, horizon+1)}

    constraints = {}
    constraint_val = {}
    constraint_indices = []

    # The reward depends only on the agent's cell, so build the vector once
    # and hand every (t, a) its own copy.
    reward_vec = np.zeros(n_states, dtype=float)
    for cell in range(m*n):
        for k in range(n_o):
            reward_vec[cell*n_o + k] = rew[cell]
    rewards = {t: {a: reward_vec.copy() for a in actions[t]} for t in states}

    # Observation columns are laid out as (n*i + j)*2 for 'C' and
    # (n*i + j)*2 + 1 for 'F'.
    obs_prob = np.zeros((n_states, m*n*2))

    for s, o in state_list:
        obs_dist = man_dist(s, o)
        x, y = s[0], s[1]

        # In-grid neighbours of the agent's cell, excluding the cell itself.
        neighbours = [(x+dx, y+dy)
                      for dx in (-1, 0, 1)
                      for dy in (-1, 0, 1)
                      if (dx, dy) != (0, 0) and 0 <= x+dx < m and 0 <= y+dy < n]
        count = len(neighbours)

        # Sensor quality depends on which half of the grid the obstacle is in.
        if o[1] >= n/2.0:
            q_n = qr_n
            q_a = qr_a
        else:
            q_n = ql_n
            q_a = ql_a

        # Probability of each sensor reading given the true distance.
        if obs_dist <= 1:
            p_c, p_f = q_n, 1 - q_n
        else:
            p_c, p_f = 1 - q_a, q_a

        row = (n*x + y)*n_o + o_dict[o]

        own_col = (n*x + y)*2
        obs_prob[row][own_col] = r*p_c
        obs_prob[row][own_col + 1] = r*p_f

        for nx, ny in neighbours:
            col = (n*nx + ny)*2
            obs_prob[row][col] = ((1-r)/count)*p_c
            obs_prob[row][col + 1] = ((1-r)/count)*p_f

    observation_probability = {t: deepcopy(obs_prob) for t in states}

    return constrained_pomdp(initial_dist,states,actions,transitions,observations,observation_probability,rewards,constraints,constraint_val,constraint_indices,horizon)

